{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e0deecc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n",
    "import tilelang\n",
    "import torch\n",
    "import tilelang.language as T"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ca2c56d",
   "metadata": {},
   "source": [
    "# Tilelang Lazy JIT"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "156e7370",
   "metadata": {},
   "source": [
    "## Tensor Annotation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b070c109",
   "metadata": {},
   "source": [
    "Tilelang Lazy JIT 将 jit 生成和调用的逻辑合并到一起\n",
    "\n",
    "函数签名的写法与 triton 相似，但做了大量增强，最主要的增强是允许对 Tensor 的标注：\n",
    "\n",
    "例如，下面的代码用 T.Tensor[[int, int], T.float16] 来标注了一个二维 Tensor\n",
    "1. 它的每个维度都是编译期常量，如果改变，会触发重新编译\n",
    "2. 它的类型必须是 T.float16\n",
    "\n",
    "DType 除了写确定的外，还可以写 Any 或者 None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "60bf8954",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tilelang.lazy_jit\n",
    "def gemm(\n",
    "    A: T.Tensor[[int, int], T.float16],\n",
    "    B: T.Tensor[[int, int], T.float16],\n",
    "    out_dtype: T.dtype = T.float32,\n",
    "    block_M: int = 128,\n",
    "    block_N: int = 128,\n",
    "    block_K: int = 32\n",
    "):\n",
    "    M, K = A.shape\n",
    "    K, N = B.shape\n",
    "    C = T.empty((M, N), out_dtype)\n",
    "    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
    "        A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n",
    "        B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n",
    "        C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n",
    "        T.clear(C_local)\n",
    "        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n",
    "            T.copy(A[bx * block_M, k * block_K], A_shared)\n",
    "            T.copy(B[k * block_K, by * block_N], B_shared)\n",
    "            T.gemm(A_shared, B_shared, C_local)\n",
    "        T.copy(C_local, C[bx * block_M, by * block_N])\n",
    "    return C"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28f868fe",
   "metadata": {},
   "source": [
    "直接将 Tensor 作为参数调用，即可触发完整的 jit 编译运行流程："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ee13394a",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
    "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
    "C = gemm(A, B)\n",
    "\n",
    "# check output is correct\n",
    "C_ref = (A @ B).float()\n",
    "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6705091",
   "metadata": {},
   "source": [
    "更改调用的参数，如果编译器参数发生了变化，会触发重新编译："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d8aab5b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
    "B = torch.randn(512, 1024, dtype=torch.float16, device='cuda')\n",
    "C = gemm(A, B, block_M=64, block_N=64)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce6b7391",
   "metadata": {},
   "source": [
    "你也可以手动调用 compile 函数编译 kernel\n",
    "\n",
    "1. `ker.compile` 编译 kernel\n",
    "2. `ker.get_tir` 获取 tir\n",
    "3. `ker.par_compile` 并行编译"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f3cf3a2d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2025-11-25 17:29:46  [TileLang:tilelang.cache.kernel_cache:WARNING]: Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching.\n"
     ]
    }
   ],
   "source": [
    "kernel = gemm.compile(A, B, block_M=64, block_N=64)\n",
    "C = kernel(A, B)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "921761b5",
   "metadata": {},
   "source": [
    "## More Tensor Annotation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4539e54e",
   "metadata": {},
   "source": [
    "### 用 macro 来分离实现"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad96ba65",
   "metadata": {},
   "source": [
    "接下来，我们会用各种方式来实现一个简单的 gemm，为了方便，我们先写一个 macro 把 gemm 的主要逻辑写出来："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "171d4fe6",
   "metadata": {},
   "outputs": [],
   "source": [
    "@T.macro\n",
    "def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n",
    "    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
    "        A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n",
    "        B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n",
    "        C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n",
    "        T.clear(C_local)\n",
    "        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n",
    "            T.copy(A[bx * block_M, k * block_K], A_shared)\n",
    "            T.copy(B[k * block_K, by * block_N], B_shared)\n",
    "            T.gemm(A_shared, B_shared, C_local)\n",
    "        T.copy(C_local, C[bx * block_M, by * block_N])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "446a1acd",
   "metadata": {},
   "source": [
    "### 用 T.dyn 标记动态 Shape\n",
    "\n",
    "当某些维度是动态的的时候，可以用 T.dyn 来标记。T.dyn 可以接受一个字符串参数，表示变量的名字"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a38aa95",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tilelang.lazy_jit\n",
    "def gemm_dyn_K(\n",
    "    A: T.Tensor[[int, T.dyn['K']], T.float16], # noqa: F821\n",
    "    B: T.Tensor[[T.dyn['K'], int], T.float16], # noqa: F821\n",
    "):\n",
    "    M, K = A.shape\n",
    "    K, N = B.shape\n",
    "    C = T.empty((M, N), T.float32)\n",
    "    gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n",
    "    return C"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c60fd346",
   "metadata": {},
   "source": [
    "查看 lazy_jit 的函数签名，其中带有后缀`$` 的是不确定的编译期常量，带有 `$dyn` 的是运行时的变量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c6992eb4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'A': TensorAnnot(shape=[A_shape_0$, K$dyn], strides=None, dtype=dtype('float16')),\n",
       " 'B': TensorAnnot(shape=[K$dyn, B_shape_1$], strides=None, dtype=dtype('float16'))}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gemm_dyn_K.func.annot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fe6cfdc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
    "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
    "C = gemm_dyn_K(A, B)\n",
    "C_ref = (A @ B).float()\n",
    "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ee97bf7",
   "metadata": {},
   "source": [
    "### 用 T.StridedTensor 标记带 stride 的 Tensor\n",
    "\n",
    "标记方法：T.StridedTensor[Shape, Stride, DType]，每个 Shape 或 Stride 可以写\n",
    "* int: 表示编译期常量\n",
    "* T.dyn：表示运行时常量\n",
    "\n",
    "DType 可以写 None 或 Any"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9dde1dae",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any\n",
    "\n",
    "# Copy a possibly strided 2-D tensor into a new contiguous tensor of the same dtype.\n",
    "# Shape and strides are all runtime (T.dyn) values; the dtype is left unconstrained (Any).\n",
    "@tilelang.lazy_jit\n",
    "def as_contiguous(\n",
    "    A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]\n",
    "):\n",
    "    M, N = A.shape\n",
    "    B = T.empty((M, N), A.dtype)\n",
    "    block_M = 128\n",
    "    block_N = 128\n",
    "    # Each block copies one (block_M, block_N) tile from the strided input into B.\n",
    "    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
    "        T.copy(\n",
    "            A[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N],\n",
    "            B[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N]\n",
    "        )\n",
    "    return B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dec2c0a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.randn(1024, 1024, device='cuda')\n",
    "# A[::2, ::2] is a non-contiguous (512, 512) view with strides (2048, 2).\n",
    "B = as_contiguous(A[::2, ::2])\n",
    "B_ref = A[::2, ::2].contiguous()\n",
    "torch.testing.assert_close(B, B_ref)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5fb20d6",
   "metadata": {},
   "source": [
    "## More Annotation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "890df0a2",
   "metadata": {},
   "source": [
    "### 用 T.ptr 标注 Tensor\n",
    "lazy_jit 允许你用 T.ptr 来声明一个 handle，但必须在函数内用 T.match_buffer 给它定义 shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0fc17af6",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tilelang.lazy_jit\n",
    "def gemm_ptr(\n",
    "    A: T.ptr,\n",
    "    B: T.ptr,\n",
    "    M: int,\n",
    "    N: int,\n",
    "    K: int,\n",
    "):\n",
    "    A = T.match_buffer(A, (M, K), T.float16)\n",
    "    B = T.match_buffer(B, (K, N), T.float16)\n",
    "    C = T.empty((M, N), T.float32)\n",
    "    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n",
    "    return C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8e52a554",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
    "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
    "C = gemm_ptr(A, B, 1024, 256, 512)\n",
    "C_ref = (A @ B).float()\n",
    "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b19ef90",
   "metadata": {},
   "source": [
    "### 用 T.int32 标注运行时变量\n",
    "\n",
    "lazy_jit 允许你用 T.int32 或其他类型来定义运行时变量，这样，你可以写一个完全动态的 gemm，这和 triton 非常相似"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c1e7598a",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tilelang.lazy_jit\n",
    "def gemm_ptr_dyn(\n",
    "    A: T.ptr,\n",
    "    B: T.ptr,\n",
    "    M: T.int32,\n",
    "    N: T.int32,\n",
    "    K: T.int32,\n",
    "):\n",
    "    A = T.match_buffer(A, (M, K), T.float16, strides=(K, 1))\n",
    "    B = T.match_buffer(B, (K, N), T.float16, strides=(N, 1))\n",
    "    C = T.empty((M, N), T.float32)\n",
    "    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n",
    "    return C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "9e9a4c88",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
    "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
    "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n",
    "C_ref = (A @ B).float()\n",
    "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39166cb4",
   "metadata": {},
   "source": [
    "## 编译与并行编译"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c6fbe08",
   "metadata": {},
   "source": [
    "lazyjit 和原来的 jit 都支持并行编译\n",
    "\n",
    "为了防止 torch.tensor 白白浪费内存，可以使用 T.Tensor 来创建 placeholder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7222e57b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c6d7f05cdfff412e9a527332438f7aa2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Elaborating:   0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "14836065a21b41ae8fc34e8763ae49fc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Parallel Compiling:   0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[<tilelang.jit.kernel.JITKernel at 0x7f29c0072ed0>,\n",
       " <tilelang.jit.kernel.JITKernel at 0x7f29c00882f0>,\n",
       " <tilelang.jit.kernel.JITKernel at 0x7f29c00735f0>,\n",
       " <tilelang.jit.kernel.JITKernel at 0x7f29c0088890>,\n",
       " <tilelang.jit.kernel.JITKernel at 0x7f29c01f94c0>,\n",
       " <tilelang.jit.kernel.JITKernel at 0x7f29c0073fe0>,\n",
       " <tilelang.jit.kernel.JITKernel at 0x7f29c0070ce0>,\n",
       " <tilelang.jit.kernel.JITKernel at 0x7f29c00732f0>]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from itertools import product\n",
    "\n",
    "def get_configs():\n",
    "    return [\n",
    "        {\n",
    "            'A': T.Tensor((1024, 1024), T.float32),\n",
    "            'B': T.Tensor((1024, 1024), T.float32),\n",
    "            'block_M': block_M,\n",
    "            'block_N': block_N,\n",
    "            'block_K': block_K,\n",
    "        }\n",
    "        for block_M, block_N, block_K in product([32, 64], repeat=3)\n",
    "    ]\n",
    "\n",
    "gemm.par_compile(get_configs())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5160d2cc",
   "metadata": {},
   "source": [
    "## 更便利的 Macro"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be44afc4",
   "metadata": {},
   "source": [
    "tilelang 的 macro 现在已经升级：\n",
    "\n",
    "1. 允许用 `T.Ref` 作为 annotation，这类似与 C++ 的引用传递\n",
    "2. 允许返回多个值\n",
    "3. 允许嵌套，递归"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79575972",
   "metadata": {},
   "source": [
    "### T.Ref 传递引用\n",
    "\n",
    "T.Ref 传递的引用可以 var 也可以是 Buffer 的索引"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90eaa6e5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "# from tvm.script import tir as T\n",
       "\n",
       "@T.prim_func\n",
       "def foo(x_handle: T.handle):\n",
       "    x = T.match_buffer(x_handle, (2,), strides=(1,))\n",
       "    # with T.block(\"root\"):\n",
       "    bx = T.launch_thread(\"blockIdx.x\", 1)\n",
       "    tx = T.launch_thread(\"threadIdx.x\", 128)\n",
       "    ty = T.launch_thread(\"threadIdx.y\", 1)\n",
       "    tz = T.launch_thread(\"threadIdx.z\", 1)\n",
       "    with T.block(\"tilelang_root\"):\n",
       "        T.reads()\n",
       "        idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n",
       "        T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n",
       "        T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n",
       "        idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n",
       "        x[1] = T.float32(1.0)\n",
       "        _tmp: T.int32 = idx[0]\n",
       "        x[_tmp] = T.float32(1.0)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@T.macro\n",
    "def macro_with_ref(x: T.Ref):\n",
    "    x = 1 # noqa: F841\n",
    "\n",
    "@T.prim_func\n",
    "def foo(x: T.Tensor((2,))):\n",
    "    with T.Kernel(1) as _:\n",
    "        # 支持常量 index\n",
    "        macro_with_ref(x[1])\n",
    "\n",
    "        # 也支持变量 index\n",
    "        idx = T.alloc_var(T.int32, 0)\n",
    "        macro_with_ref(x[idx])\n",
    "\n",
    "foo"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bb447a2",
   "metadata": {},
   "source": [
    "### 当作参数传递\n",
    "\n",
    "你可以把 macro 当做参数传递"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "dc7bb779",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tilelang.lazy_jit\n",
    "def element_wise(\n",
    "    A: T.Tensor[[T.dyn], Any],\n",
    "    fn,\n",
    "):\n",
    "    N, = A.shape\n",
    "    B = T.empty((N,), dtype=A.dtype)\n",
    "    block_N = 128\n",
    "    with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n",
    "        for i in T.Parallel(block_N):\n",
    "            idx = bx * block_N + i\n",
    "            B[idx] = fn(A[idx])\n",
    "    return B\n",
    "@T.macro\n",
    "def add_one(x):\n",
    "    return x + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "a89fdb44",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.randn(1024, device='cuda')\n",
    "B = element_wise(A, add_one)\n",
    "B_ref = A + 1\n",
    "torch.testing.assert_close(B, B_ref)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef6e403a",
   "metadata": {},
   "source": [
    "### Macro 递归\n",
    "\n",
    "虽然不知道有没有这种需求，但 macro 是可以递归的，但要求终止条件编译期间确定"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "7703cab5",
   "metadata": {},
   "outputs": [],
   "source": [
    "@T.macro\n",
    "def n31(x, var: T.Ref):\n",
    "    if x == 1:\n",
    "        pass\n",
    "    elif x % 2 == 0:\n",
    "        var = var // 2\n",
    "        n31(x // 2, var)\n",
    "    else:\n",
    "        var = var * 3 + 1\n",
    "        n31(x * 3 + 1, var)\n",
    "\n",
    "@tilelang.lazy_jit\n",
    "def foo(A: T.Tensor[[1], T.int32], n: int):\n",
    "    with T.Kernel(1) as _:\n",
    "        n31(n, A[0])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "542ddd4e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([18], device='cuda:0', dtype=torch.int32)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "A = torch.tensor([100], dtype=torch.int32, device='cuda')\n",
    "foo(A, 5)\n",
    "A"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc30c2d2",
   "metadata": {},
   "source": [
    "### Macro 返回多个值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5a2388f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "# from tvm.script import tir as T\n",
       "\n",
       "@T.prim_func\n",
       "def foo():\n",
       "    # with T.block(\"root\"):\n",
       "    x = T.launch_thread(\"blockIdx.x\", 32)\n",
       "    tx = T.launch_thread(\"threadIdx.x\", 128)\n",
       "    ty = T.launch_thread(\"threadIdx.y\", 1)\n",
       "    tz = T.launch_thread(\"threadIdx.z\", 1)\n",
       "    with T.block(\"tilelang_root\"):\n",
       "        T.reads()\n",
       "        T.writes()\n",
       "        s: T.int32 = T.sin(x)\n",
       "        c: T.int32 = T.cos(x)\n",
       "        a: T.int32 = s + c\n",
       "        b: T.int32 = s - c\n",
       "        T.evaluate(0)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@T.macro\n",
    "def sincos(x):\n",
    "    return T.sin(x), T.cos(x)\n",
    "\n",
    "@T.prim_func\n",
    "def foo():\n",
    "    with T.Kernel(32) as x:\n",
    "        s, c = sincos(x)\n",
    "        a = s + c # noqa: F841\n",
    "        b = s - c # noqa: F841\n",
    "foo"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tilelang-dev_0",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
